import torch
from torch import nn
import torch.nn.functional as F
from ..utils import create_activation, create_norm

class MeanAct(nn.Module):
    """Activation that turns raw decoder outputs into positive ZINB means.

    Parameters
    ----------
    softmax : bool
        When truthy, apply a softmax over dimension 1 and multiply by
        ``standardscale``. Otherwise apply ``exp`` clamped to ``[1e-5, 1e6]``.
    standardscale : float
        Multiplier for the softmax output. It is unused in the ``exp`` branch.
    """

    def __init__(self, softmax, standardscale):
        super().__init__()
        self.softmax = softmax
        self.standardscale = standardscale

    def forward(self, x):
        if self.softmax:
            # Slices along dim 1 sum to ``standardscale`` after rescaling.
            return torch.softmax(x, 1) * self.standardscale
        # Clamp keeps exp() away from 0 and from overflowing to inf.
        return torch.exp(x).clamp(min=1e-5, max=1e6)

class DispAct(nn.Module):
    """Dispersion activation: softplus output clamped to ``[1e-4, 1e4]``."""

    def forward(self, x):
        dispersion = F.softplus(x)
        # Bound the dispersion so downstream likelihood terms stay finite.
        return dispersion.clamp(min=1e-4, max=1e4)

class ZINB(nn.Module):
    """ZINB decoder head.

    Maps a hidden representation to the mean, dispersion and dropout
    probability of a zero-inflated negative binomial (ZINB) distribution.

    Parameters
    ----------
    hidden_dim : int
        Dimension of the input feature (last dimension of ``z``).
    out_dim : int
        Number of output features for each of the three ZINB parameters.
    standardscale : float, optional
        Multiplier for the softmax-normalised mean (used only when
        ``softmax`` is true).
    n_dec_1 : int, optional
        Number of nodes of the shared hidden decoder layer.
    softmax : bool, optional
        Mean head uses a scaled softmax when true, a clamped ``exp`` otherwise.
    disp : str, optional
        ``'gene'``: one learnable dispersion per output feature, shared by all
        rows of ``z``. Any other value: dispersion is predicted per row from
        the hidden layer.
    """

    def __init__(self, hidden_dim, out_dim, standardscale=1e4, n_dec_1=256, softmax=True, disp='gene_cell'):
        super().__init__()
        self.dec_1 = nn.Linear(hidden_dim, n_dec_1)
        # All three heads emit ``out_dim`` features. (``self.input_dim`` was
        # never defined, so construction used to raise AttributeError.)
        self.dec_mean = nn.Sequential(nn.Linear(n_dec_1, out_dim), MeanAct(softmax, standardscale))
        self.dec_pi = nn.Sequential(nn.Linear(n_dec_1, out_dim), nn.Sigmoid())
        self.disp = disp
        if disp == 'gene':
            # NOTE(review): this raw parameter is not passed through DispAct,
            # so nothing keeps it positive during training. Confirm intended.
            self.dec_disp = nn.Parameter(torch.ones(out_dim))
        else:
            self.dec_disp = nn.Sequential(nn.Linear(n_dec_1, out_dim), DispAct())

    def forward(self, z):
        """Forward propagation.

        Parameters
        ----------
        z :
            Embedding. Assumed to be 2-D ``(n, hidden_dim)``, since the shared
            dispersion is tiled to ``z.shape[0]`` rows.

        Returns
        -------
        _mean :
            Data mean from ZINB.
        _disp :
            Data dispersion from ZINB.
        _pi :
            Data dropout probability from ZINB (sigmoid output in ``(0, 1)``).
        """
        h = F.relu(self.dec_1(z))
        _mean = self.dec_mean(h)
        if self.disp == 'gene':
            # Tile the per-feature dispersion to one row per input row.
            _disp = self.dec_disp.repeat(z.shape[0], 1)
        else:
            _disp = self.dec_disp(h)
        _pi = self.dec_pi(h)
        return _mean, _disp, _pi

class ZINBResMLPDecoder(nn.Module):
    """MLP decoder whose per-layer outputs are all concatenated into a ZINB head.

    Parameters
    ----------
    in_dim : int
        Width of the input representation ``x_dict['h']``.
    hidden_dim : int
        Width of every hidden layer and of the batch embedding.
    out_dim : int
        Number of output features produced by the ZINB head.
    num_layers : int
        ``num_layers - 1`` hidden layers are built. Must be greater than 1.
    dropout : float
        Dropout probability applied after each hidden activation.
    norm :
        Normalisation spec forwarded to ``create_norm``.
    batch_num : int
        Number of rows in the batch-label embedding table.
    standardscale : float, optional
        Forwarded to the ZINB head's mean activation.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num, standardscale=1e4):
        super().__init__()
        self.layers = nn.ModuleList()
        assert num_layers > 1, 'At least two layer for MLPs.'
        n_hidden = num_layers - 1
        for idx in range(n_hidden):
            width_in = in_dim if idx == 0 else hidden_dim
            block = nn.Sequential(
                nn.Linear(width_in, hidden_dim),
                nn.PReLU(),
                nn.Dropout(dropout),
                create_norm(norm, hidden_dim),
            )
            self.layers.append(block)
        # The head consumes the concatenation of every hidden layer's output.
        self.out_layer = ZINB(hidden_dim * n_hidden, out_dim, standardscale)
        self.batch_emb = nn.Embedding(batch_num, hidden_dim)

    def forward(self, x_dict):
        """Decode ``x_dict['h']`` conditioned on ``x_dict['batch']``.

        The ``'batch'`` entry is used as indices into ``batch_emb``. Returns a
        dict with the ZINB ``mean``/``disp``/``pi``, ``recon`` =
        ``log((1 - pi) * mean + 1)``, and ``latent`` (the input ``h``).
        """
        labels = x_dict['batch']
        out = x_dict['h']
        skips = []
        for mlp in self.layers:
            # The batch embedding is re-added after every hidden layer.
            out = mlp(out) + self.batch_emb(labels)
            skips.append(out)
        mean, disp, pi = self.out_layer(torch.cat(skips, 1))
        recon = torch.log((1 - pi) * mean + 1)
        return {'mean': mean, 'disp': disp, 'pi': pi, 'recon': recon, 'latent': x_dict['h']}

class ZINBMLPDecoder(nn.Module):
    """Plain MLP decoder feeding its last hidden layer into a ZINB head.

    Parameters
    ----------
    in_dim : int
        Width of the input representation ``x_dict['h']``.
    hidden_dim : int
        Width of every hidden layer and of the batch embedding.
    out_dim : int
        Number of output features produced by the ZINB head.
    num_layers : int
        ``num_layers - 1`` hidden layers are built. Must be greater than 1.
    dropout : float
        Dropout probability applied after each hidden activation.
    norm :
        Normalisation spec forwarded to ``create_norm``.
    batch_num : int
        Number of rows in the batch-label embedding table.
    standardscale : float, optional
        Forwarded to the ZINB head's mean activation.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, dropout, norm, batch_num, standardscale=1e4):
        super().__init__()
        self.layers = nn.ModuleList()
        assert num_layers > 1, 'At least two layer for MLPs.'
        for i in range(num_layers - 1):
            dim = hidden_dim if i > 0 else in_dim
            self.layers.append(nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.PReLU(),
                nn.Dropout(dropout),
                create_norm(norm, hidden_dim)
            ))
        # forward() passes only the last layer's output (width hidden_dim) to
        # the head. Unlike ZINBResMLPDecoder, nothing is concatenated. Sizing
        # the head as hidden_dim * (num_layers - 1) broke num_layers > 2.
        self.out_layer = ZINB(
            hidden_dim, out_dim,
            standardscale
        )
        self.batch_emb = nn.Embedding(batch_num, hidden_dim)

    def forward(self, x_dict):
        """Decode ``x_dict['h']`` conditioned on ``x_dict['batch']``.

        The ``'batch'`` entry is used as indices into ``batch_emb``. Returns a
        dict with the ZINB ``mean``/``disp``/``pi``, ``recon`` =
        ``log((1 - pi) * mean + 1)``, and ``latent`` (the input ``h``).
        """
        batch_labels = x_dict['batch']
        x = x_dict['h']
        for layer in self.layers:
            x = layer(x)
        # Batch embedding is added once, after the last hidden layer.
        x = x + self.batch_emb(batch_labels)
        mean, disp, pi = self.out_layer(x)
        return {'mean': mean, 'disp': disp, 'pi': pi, 'recon': torch.log((1 - pi) * mean + 1), 'latent': x_dict['h']}